{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Iris dataset example\n",
    "\n",
    "This example showcases using K3D to visualize data. Addtional requirements: `scikit-learn`"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2019-04-18T20:40:37.580348Z",
     "start_time": "2019-04-18T20:40:37.394862Z"
    }
   },
   "outputs": [],
   "source": [
    "try:\n",
    "    import sklearn\n",
    "except ImportError:\n",
    "    %pip install scikit-learn"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2019-04-18T20:40:46.224530Z",
     "start_time": "2019-04-18T20:40:45.715673Z"
    }
   },
   "outputs": [],
   "source": [
    "from sklearn.datasets import load_iris\n",
    "from sklearn.manifold import TSNE\n",
    "\n",
    "from k3d import plot, points, nice_colors, text2d"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2019-04-18T20:40:47.160075Z",
     "start_time": "2019-04-18T20:40:47.140711Z"
    }
   },
   "outputs": [],
   "source": [
    "iris = load_iris()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Plots omitting each column once"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2019-04-18T20:40:49.492820Z",
     "start_time": "2019-04-18T20:40:49.483242Z"
    }
   },
   "outputs": [],
   "source": [
    "from k3d import plot, points, nice_colors, text2d\n",
    "import numpy as np\n",
    "\n",
    "def legend(p, axes):\n",
    "    \"\"\"Display classes' names in their color.\"\"\"\n",
    "    k3dplot = plot(axes=['\\:'.join(a.split(' ')) for a in axes.tolist()])\n",
    "    k3dplot += p\n",
    "    for i, name in enumerate(iris.target_names):\n",
    "        k3dplot += text2d(text=name, color=nice_colors[i], position=(0, i / 10))\n",
    "    return k3dplot\n",
    "\n",
    "def point_size(data, resolution=20.):\n",
    "    span = max(np.max(data, axis=0) - np.min(data, axis=0))\n",
    "    return span / resolution\n",
    "\n",
    "common = dict(\n",
    "    point_size=point_size(iris.data), \n",
    "    colors=[nice_colors[i] for i in iris.target]\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2019-04-18T20:40:50.222684Z",
     "start_time": "2019-04-18T20:40:50.210636Z"
    }
   },
   "outputs": [],
   "source": [
    "def iris_omit(column_index):\n",
    "    names = np.roll(iris.feature_names, column_index)\n",
    "    print('omitting', names[0])\n",
    "    for dim, name in zip('xyz', names[1:]):\n",
    "        print(dim, 'is', name)\n",
    "    return (np.roll(iris.data, column_index, axis=1)[:, 1:].astype(np.float32), names[1:])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2019-04-18T20:40:51.830949Z",
     "start_time": "2019-04-18T20:40:51.775439Z"
    },
    "scrolled": false
   },
   "outputs": [],
   "source": [
    "p, axes = iris_omit(0)\n",
    "legend(points(p, **common), axes)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2019-04-18T13:40:41.002294Z",
     "start_time": "2019-04-18T13:40:40.965961Z"
    },
    "scrolled": false
   },
   "outputs": [],
   "source": [
    "p, axes = iris_omit(1)\n",
    "legend(points(p, **common), axes)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2019-04-18T13:40:42.317941Z",
     "start_time": "2019-04-18T13:40:42.276879Z"
    },
    "scrolled": false
   },
   "outputs": [],
   "source": [
    "p, axes = iris_omit(2)\n",
    "legend(points(p, **common), axes)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2019-04-18T13:40:43.365839Z",
     "start_time": "2019-04-18T13:40:43.318006Z"
    },
    "scrolled": false
   },
   "outputs": [],
   "source": [
    "p, axes = iris_omit(3)\n",
    "legend(points(p, **common), axes)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# 3-D trainable Stochastic Neighbor Embedding\n",
    "\n",
    "This can look a little different every time around.\n",
    "\n",
    "**NOTE**: this is unsupervised learning, tSNE doesn't get the labels. And still, it clusters."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2019-04-18T20:43:19.952715Z",
     "start_time": "2019-04-18T20:43:19.945843Z"
    }
   },
   "outputs": [],
   "source": [
    "tsne = TSNE(n_components=3, verbose=1, perplexity=40, n_iter=3000)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2019-04-18T20:43:26.641354Z",
     "start_time": "2019-04-18T20:43:20.582832Z"
    }
   },
   "outputs": [],
   "source": [
    "tsne_results = tsne.fit_transform(iris.data).astype(np.float32)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2019-04-18T20:43:26.650496Z",
     "start_time": "2019-04-18T20:43:26.643330Z"
    }
   },
   "outputs": [],
   "source": [
    "tsne_results[:5]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2019-04-18T20:43:39.427908Z",
     "start_time": "2019-04-18T20:43:39.399251Z"
    }
   },
   "outputs": [],
   "source": [
    "common['point_size'] = point_size(tsne_results)\n",
    "legend(points(tsne_results, **common), axes)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.5"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
